{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 228
    },
    "colab_type": "code",
    "id": "lIYdn1woOS1n",
    "outputId": "05f43a3e-f111-4f96-ee3e-d95027c041c8"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "import torchtext\n",
    "import torchtext.experimental\n",
    "import torchtext.experimental.vectors\n",
    "from torchtext.experimental.datasets.raw.text_classification import RawTextIterableDataset\n",
    "from torchtext.experimental.datasets.text_classification import TextClassificationDataset\n",
    "from torchtext.experimental.functional import sequential_transforms, vocab_func, totensor\n",
    "\n",
    "import collections\n",
    "import random\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kjHAEB8BKbEY"
   },
   "outputs": [],
   "source": [
    "seed = 1234\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "random.seed(seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HRkCva2fJ_kr"
   },
   "outputs": [],
   "source": [
    "raw_train_data, raw_test_data = torchtext.experimental.datasets.raw.IMDB()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RkgVHXXSKAyU"
   },
   "outputs": [],
   "source": [
    "def get_train_valid_split(raw_train_data, split_ratio = 0.7):\n",
    "\n",
    "    raw_train_data = list(raw_train_data)\n",
    "        \n",
    "    random.shuffle(raw_train_data)\n",
    "        \n",
    "    n_train_examples = int(len(raw_train_data) * split_ratio)\n",
    "        \n",
    "    train_data = raw_train_data[:n_train_examples]\n",
    "    valid_data = raw_train_data[n_train_examples:]\n",
    "    \n",
    "    train_data = RawTextIterableDataset(train_data)\n",
    "    valid_data = RawTextIterableDataset(valid_data)\n",
    "    \n",
    "    return train_data, valid_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "T5fGSB1OKC77"
   },
   "outputs": [],
   "source": [
    "raw_train_data, raw_valid_data = get_train_valid_split(raw_train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zvcEouXQLmHz"
   },
   "outputs": [],
   "source": [
    "class Tokenizer:\n",
    "    def __init__(self, tokenize_fn = 'basic_english', lower = True, max_length = None):\n",
    "        \n",
    "        self.tokenize_fn = torchtext.data.utils.get_tokenizer(tokenize_fn)\n",
    "        self.lower = lower\n",
    "        self.max_length = max_length\n",
    "        \n",
    "    def tokenize(self, s):\n",
    "        \n",
    "        tokens = self.tokenize_fn(s)\n",
    "        \n",
    "        if self.lower:\n",
    "            tokens = [token.lower() for token in tokens]\n",
    "            \n",
    "        if self.max_length is not None:\n",
    "            tokens = tokens[:self.max_length]\n",
    "            \n",
    "        return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "dnpijQRFLnXV"
   },
   "outputs": [],
   "source": [
    "max_length = 500\n",
    "\n",
    "tokenizer = Tokenizer(max_length = max_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "VOl6UxZoLdg_"
   },
   "outputs": [],
   "source": [
    "def build_vocab_from_data(raw_data, tokenizer, **vocab_kwargs):\n",
    "    \n",
    "    token_freqs = collections.Counter()\n",
    "    \n",
    "    for label, text in raw_data:\n",
    "        tokens = tokenizer.tokenize(text)\n",
    "        token_freqs.update(tokens)\n",
    "                \n",
    "    vocab = torchtext.vocab.Vocab(token_freqs, **vocab_kwargs)\n",
    "    \n",
    "    return vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eNLrpvt2Lgsr"
   },
   "outputs": [],
   "source": [
    "max_size = 25_000\n",
    "\n",
    "vocab = build_vocab_from_data(raw_train_data, tokenizer, max_size = max_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "AN1YQiYfLr0_"
   },
   "outputs": [],
   "source": [
    "def process_raw_data(raw_data, tokenizer, vocab):\n",
    "    \n",
    "    raw_data = [(label, text) for (label, text) in raw_data]\n",
    "\n",
    "    text_transform = sequential_transforms(tokenizer.tokenize,\n",
    "                                           vocab_func(vocab),\n",
    "                                           totensor(dtype=torch.long))\n",
    "    \n",
    "    label_transform = sequential_transforms(totensor(dtype=torch.long))\n",
    "\n",
    "    transforms = (label_transform, text_transform)\n",
    "\n",
    "    dataset = TextClassificationDataset(raw_data,\n",
    "                                        vocab,\n",
    "                                        transforms)\n",
    "    \n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "dlejEwWLMScW"
   },
   "outputs": [],
   "source": [
    "train_data = process_raw_data(raw_train_data, tokenizer, vocab)\n",
    "valid_data = process_raw_data(raw_valid_data, tokenizer, vocab)\n",
    "test_data = process_raw_data(raw_test_data, tokenizer, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hggYldmOQahU"
   },
   "outputs": [],
   "source": [
    "class Collator:\n",
    "    def __init__(self, pad_idx):\n",
    "        \n",
    "        self.pad_idx = pad_idx\n",
    "        \n",
    "    def collate(self, batch):\n",
    "        \n",
    "        labels, text = zip(*batch)\n",
    "        \n",
    "        labels = torch.LongTensor(labels)\n",
    "        \n",
    "        lengths = torch.LongTensor([len(x) for x in text])\n",
    "\n",
    "        text = nn.utils.rnn.pad_sequence(text, padding_value = self.pad_idx)\n",
    "        \n",
    "        return labels, text, lengths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gw4LBXWAQiEC"
   },
   "outputs": [],
   "source": [
    "pad_token = '<pad>'\n",
    "pad_idx = vocab[pad_token]\n",
    "\n",
    "collator = Collator(pad_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "d0dP9wnZQjaU"
   },
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "\n",
    "train_iterator = torch.utils.data.DataLoader(train_data, \n",
    "                                             batch_size, \n",
    "                                             shuffle = True, \n",
    "                                             collate_fn = collator.collate)\n",
    "\n",
    "valid_iterator = torch.utils.data.DataLoader(valid_data, \n",
    "                                             batch_size, \n",
    "                                             shuffle = False, \n",
    "                                             collate_fn = collator.collate)\n",
    "\n",
    "test_iterator = torch.utils.data.DataLoader(test_data, \n",
    "                                            batch_size, \n",
    "                                            shuffle = False, \n",
    "                                            collate_fn = collator.collate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GRU(nn.Module):\n",
    "    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
    "        self.gru = nn.GRUCell(emb_dim, hid_dim)\n",
    "        self.fc = nn.Linear(hid_dim, output_dim)\n",
    "\n",
    "    def forward(self, text, lengths):\n",
    "\n",
    "        # text = [seq len, batch size]\n",
    "        # lengths = [batch size]\n",
    "\n",
    "        embedded = self.embedding(text)\n",
    "\n",
    "        # embedded = [seq len, batch size, emb dim]\n",
    "\n",
    "        seq_len, batch_size, _ = embedded.shape\n",
    "        hid_dim = self.gru.hidden_size\n",
    "                \n",
    "        hidden = torch.zeros(batch_size, hid_dim).to(embedded.device)\n",
    "        \n",
    "        for i in range(seq_len):\n",
    "            x = embedded[i]\n",
    "            hidden = self.gru(x, hidden)\n",
    "        \n",
    "        prediction = self.fc(hidden)\n",
    "\n",
    "        # prediction = [batch size, output dim]\n",
    "\n",
    "        return prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GRU(nn.Module):\n",
    "    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
    "        self.gru = nn.GRU(emb_dim, hid_dim)\n",
    "        self.fc = nn.Linear(hid_dim, output_dim)\n",
    "\n",
    "    def forward(self, text, lengths):\n",
    "\n",
    "        # text = [seq len, batch size]\n",
    "        # lengths = [batch size]\n",
    "\n",
    "        embedded = self.embedding(text)\n",
    "\n",
    "        # embedded = [seq len, batch size, emb dim]\n",
    "\n",
    "        output, hidden = self.gru(embedded)\n",
    "\n",
    "        # output = [seq_len, batch size, n directions * hid dim]\n",
    "        # hidden = [n layers * n directions, batch size, hid dim]\n",
    "\n",
    "        prediction = self.fc(hidden.squeeze(0))\n",
    "\n",
    "        # prediction = [batch size, output dim]\n",
    "\n",
    "        return prediction "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LGQ5JkfBQll0"
   },
   "outputs": [],
   "source": [
    "class GRU(nn.Module):\n",
    "    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
    "        self.gru = nn.GRU(emb_dim, hid_dim)\n",
    "        self.fc = nn.Linear(hid_dim, output_dim)\n",
    "\n",
    "    def forward(self, text, lengths):\n",
    "\n",
    "        # text = [seq len, batch size]\n",
    "        # lengths = [batch size]\n",
    "\n",
    "        embedded = self.embedding(text)\n",
    "\n",
    "        # embedded = [seq len, batch size, emb dim]\n",
    "\n",
    "        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths, enforce_sorted = False)\n",
    "\n",
    "        packed_output, hidden = self.gru(packed_embedded)\n",
    "\n",
    "        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output)\n",
    "\n",
    "        # output = [seq_len, batch size, n directions * hid dim]\n",
    "        # hidden = [n layers * n directions, batch size, hid dim]\n",
    "\n",
    "        prediction = self.fc(hidden.squeeze(0))\n",
    "\n",
    "        # prediction = [batch size, output dim]\n",
    "\n",
    "        return prediction "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mEb-ff-bQtKL"
   },
   "outputs": [],
   "source": [
    "input_dim = len(vocab)\n",
    "emb_dim = 100\n",
    "hid_dim = 256\n",
    "output_dim = 2\n",
    "\n",
    "model = GRU(input_dim, emb_dim, hid_dim, output_dim, pad_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WEwnyJT_Tm8q"
   },
   "outputs": [],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "SJdVErKTTogS",
    "outputId": "aaf74c2e-2b9f-47df-a672-b809ffffd6e5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model has 2,775,658 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "print(f'The model has {count_parameters(model):,} trainable parameters')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name: embedding.weight, shape: torch.Size([25002, 100])\n",
      "name: gru.weight_ih_l0, shape: torch.Size([768, 100])\n",
      "name: gru.weight_hh_l0, shape: torch.Size([768, 256])\n",
      "name: gru.bias_ih_l0, shape: torch.Size([768])\n",
      "name: gru.bias_hh_l0, shape: torch.Size([768])\n",
      "name: fc.weight, shape: torch.Size([2, 256])\n",
      "name: fc.bias, shape: torch.Size([2])\n"
     ]
    }
   ],
   "source": [
    "for n, p in model.named_parameters():\n",
    "    print(f'name: {n}, shape: {p.shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_parameters(m):\n",
    "    if isinstance(m, nn.Embedding):\n",
    "        nn.init.uniform_(m.weight, -0.05, 0.05)\n",
    "    elif isinstance(m, nn.GRU):\n",
    "        for n, p in m.named_parameters():\n",
    "            if 'weight_ih' in n:\n",
    "                r, z, n = p.chunk(3)\n",
    "                nn.init.xavier_uniform_(r)\n",
    "                nn.init.xavier_uniform_(z)\n",
    "                nn.init.xavier_uniform_(n)\n",
    "            elif 'weight_hh' in n:\n",
    "                r, z, n = p.chunk(3)\n",
    "                nn.init.orthogonal_(r)\n",
    "                nn.init.orthogonal_(z)\n",
    "                nn.init.orthogonal_(n)\n",
    "            elif 'bias' in n:\n",
    "                r, z, n = p.chunk(3)\n",
    "                nn.init.zeros_(r)\n",
    "                nn.init.zeros_(z)\n",
    "                nn.init.zeros_(n)\n",
    "    elif isinstance(m, nn.Linear):\n",
    "        nn.init.xavier_uniform_(m.weight)\n",
    "        nn.init.zeros_(m.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GRU(\n",
       "  (embedding): Embedding(25002, 100, padding_idx=1)\n",
       "  (gru): GRU(100, 256)\n",
       "  (fc): Linear(in_features=256, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.apply(initialize_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HE9cEN3XTpf7"
   },
   "outputs": [],
   "source": [
    "glove = torchtext.experimental.vectors.GloVe(name = '6B',\n",
    "                                             dim = emb_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "AyI08bfvTrCV"
   },
   "outputs": [],
   "source": [
    "def get_pretrained_embedding(initial_embedding, pretrained_vectors, vocab, unk_token):\n",
    "    \n",
    "    pretrained_embedding = torch.FloatTensor(initial_embedding.weight.clone()).detach()    \n",
    "    pretrained_vocab = pretrained_vectors.vectors.get_stoi()\n",
    "    \n",
    "    unk_tokens = []\n",
    "    \n",
    "    for idx, token in enumerate(vocab.itos):\n",
    "        if token in pretrained_vocab:\n",
    "            pretrained_vector = pretrained_vectors[token]\n",
    "            pretrained_embedding[idx] = pretrained_vector\n",
    "        else:\n",
    "            unk_tokens.append(token)\n",
    "        \n",
    "    return pretrained_embedding, unk_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "GPMcsd6HTtoC"
   },
   "outputs": [],
   "source": [
    "unk_token = '<unk>'\n",
    "\n",
    "pretrained_embedding, unk_tokens = get_pretrained_embedding(model.embedding, glove, vocab, unk_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "colab_type": "code",
    "id": "LhlnYb2ZTvPr",
    "outputId": "8d56d0e2-6af1-40fe-ea1e-9ec7a42d8b15"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0098,  0.0150, -0.0099,  ...,  0.0211, -0.0092,  0.0027],\n",
       "        [ 0.0347,  0.0276,  0.0468,  ..., -0.0315, -0.0472, -0.0326],\n",
       "        [-0.0382, -0.2449,  0.7281,  ..., -0.1459,  0.8278,  0.2706],\n",
       "        ...,\n",
       "        [-0.2925,  0.1087,  0.7920,  ..., -0.3641,  0.1822, -0.4104],\n",
       "        [-0.7250,  0.7545,  0.1637,  ..., -0.0144, -0.1761,  0.3418],\n",
       "        [ 1.1753,  0.0460, -0.3542,  ...,  0.4510,  0.0485, -0.4015]])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.embedding.weight.data.copy_(pretrained_embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.embedding.weight.data[pad_idx] = torch.zeros(emb_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0098,  0.0150, -0.0099,  ...,  0.0211, -0.0092,  0.0027],\n",
       "        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "        [-0.0382, -0.2449,  0.7281,  ..., -0.1459,  0.8278,  0.2706],\n",
       "        ...,\n",
       "        [-0.2925,  0.1087,  0.7920,  ..., -0.3641,  0.1822, -0.4104],\n",
       "        [-0.7250,  0.7545,  0.1637,  ..., -0.0144, -0.1761,  0.3418],\n",
       "        [ 1.1753,  0.0460, -0.3542,  ...,  0.4510,  0.0485, -0.4015]])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.embedding.weight.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Sji9nWvaTxcp"
   },
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "a4Q-afN8Tyqr"
   },
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PjZOAABMT0-T"
   },
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6cYt2pfoT3TD"
   },
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = criterion.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SSdhLxTJT4mn"
   },
   "outputs": [],
   "source": [
    "def calculate_accuracy(predictions, labels):\n",
    "    top_predictions = predictions.argmax(1, keepdim = True)\n",
    "    correct = top_predictions.eq(labels.view_as(top_predictions)).sum()\n",
    "    accuracy = correct.float() / labels.shape[0]\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EoJT5j-1T54w"
   },
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion, device):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for labels, text, lengths in iterator:\n",
    "        \n",
    "        labels = labels.to(device)\n",
    "        text = text.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        predictions = model(text, lengths)\n",
    "        \n",
    "        loss = criterion(predictions, labels)\n",
    "        \n",
    "        acc = calculate_accuracy(predictions, labels)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        epoch_acc += acc.item()\n",
    "\n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UBh7g1cnUBMG"
   },
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion, device):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "    \n",
    "        for labels, text, lengths in iterator:\n",
    "\n",
    "            labels = labels.to(device)\n",
    "            text = text.to(device)\n",
    "            \n",
    "            predictions = model(text, lengths)\n",
    "            \n",
    "            loss = criterion(predictions, labels)\n",
    "            \n",
    "            acc = calculate_accuracy(predictions, labels)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jSMtdoeSUDAH"
   },
   "outputs": [],
   "source": [
    "def epoch_time(start_time, end_time):\n",
    "    elapsed_time = end_time - start_time\n",
    "    elapsed_mins = int(elapsed_time / 60)\n",
    "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
    "    return elapsed_mins, elapsed_secs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 537
    },
    "colab_type": "code",
    "id": "lG-dJsjFUF8x",
    "outputId": "c434d13f-4efa-4a7c-c346-5e886db0405d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01 | Epoch Time: 0m 7s\n",
      "\tTrain Loss: 0.654 | Train Acc: 60.73%\n",
      "\t Val. Loss: 0.584 |  Val. Acc: 68.87%\n",
      "Epoch: 02 | Epoch Time: 0m 7s\n",
      "\tTrain Loss: 0.423 | Train Acc: 80.73%\n",
      "\t Val. Loss: 0.332 |  Val. Acc: 86.04%\n",
      "Epoch: 03 | Epoch Time: 0m 7s\n",
      "\tTrain Loss: 0.252 | Train Acc: 90.15%\n",
      "\t Val. Loss: 0.285 |  Val. Acc: 88.63%\n",
      "Epoch: 04 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.186 | Train Acc: 93.05%\n",
      "\t Val. Loss: 0.286 |  Val. Acc: 89.40%\n",
      "Epoch: 05 | Epoch Time: 0m 7s\n",
      "\tTrain Loss: 0.116 | Train Acc: 95.85%\n",
      "\t Val. Loss: 0.307 |  Val. Acc: 89.56%\n",
      "Epoch: 06 | Epoch Time: 0m 7s\n",
      "\tTrain Loss: 0.065 | Train Acc: 97.90%\n",
      "\t Val. Loss: 0.354 |  Val. Acc: 89.64%\n",
      "Epoch: 07 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.042 | Train Acc: 98.74%\n",
      "\t Val. Loss: 0.403 |  Val. Acc: 89.35%\n",
      "Epoch: 08 | Epoch Time: 0m 8s\n",
      "\tTrain Loss: 0.020 | Train Acc: 99.47%\n",
      "\t Val. Loss: 0.408 |  Val. Acc: 89.35%\n",
      "Epoch: 09 | Epoch Time: 0m 7s\n",
      "\tTrain Loss: 0.010 | Train Acc: 99.81%\n",
      "\t Val. Loss: 0.505 |  Val. Acc: 88.53%\n",
      "Epoch: 10 | Epoch Time: 0m 7s\n",
      "\tTrain Loss: 0.007 | Train Acc: 99.85%\n",
      "\t Val. Loss: 0.657 |  Val. Acc: 88.27%\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 10\n",
    "\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "\n",
    "    start_time = time.monotonic()\n",
    "    \n",
    "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)\n",
    "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)\n",
    "    \n",
    "    end_time = time.monotonic()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    \n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'gru-model.pt')\n",
    "    \n",
    "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "PH7-0f6nUKRb",
    "outputId": "faf1e6dd-c99e-4fda-c6f8-435a08ca0073"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.290 | Test Acc: 87.93%\n"
     ]
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('gru-model.pt'))\n",
    "\n",
    "test_loss, test_acc = evaluate(model, test_iterator, criterion, device)\n",
    "\n",
    "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rnWNSo8kdcl_"
   },
   "outputs": [],
   "source": [
    "def predict_sentiment(tokenizer, vocab, model, device, sentence):\n",
    "    model.eval()\n",
    "    tokens = tokenizer.tokenize(sentence)\n",
    "    length = torch.LongTensor([len(tokens)]).to(device)\n",
    "    indexes = [vocab.stoi[token] for token in tokens]\n",
    "    tensor = torch.LongTensor(indexes).unsqueeze(-1).to(device)\n",
    "    prediction = model(tensor, length)\n",
    "    probabilities = nn.functional.softmax(prediction, dim = -1)\n",
    "    pos_probability = probabilities.squeeze()[-1].item()\n",
    "    return pos_probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "hb7bC-aEeC1q",
    "outputId": "059cccd1-efb4-404c-81f9-606983c23b33"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.06520231813192368"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentence = 'the absolute worst movie of all time.'\n",
    "\n",
    "predict_sentiment(tokenizer, vocab, model, device, sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "APEVZ3D4eEVw",
    "outputId": "0d188e29-6e4e-4183-c7aa-467ea8f1afe6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8539475798606873"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentence = 'one of the greatest films i have ever seen in my life.'\n",
    "\n",
    "predict_sentiment(tokenizer, vocab, model, device, sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "X7GMey_jebjg",
    "outputId": "04ca4196-51f0-4661-ffe4-8f4dd199baf4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.15590433776378632"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
    "but it was actually the absolute worst movie of all time.\"\n",
    "\n",
    "predict_sentiment(tokenizer, vocab, model, device, sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "kOoESlQSxYx2",
    "outputId": "e5826bef-5f9c-41f6-9eb0-795318280045"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3470574617385864"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentence = \"i thought it was going to be the absolute worst movie of all time, \\\n",
    "but it was actually one of the greatest films i have ever seen in my life.\"\n",
    "\n",
    "predict_sentiment(tokenizer, vocab, model, device, sentence)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "machine_shape": "hm",
   "name": "2_rnn_gru.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
